from typing import *


class Solution:
    def countValidSelections(self, nums: List[int]) -> int:
        """Score every zero in ``nums`` by how balanced its two sides are.

        For each element equal to zero, the sum of the elements strictly to its
        left is compared with the sum of the elements strictly to its right:

        * equal sums contribute 2,
        * sums differing by exactly 1 contribute 1,
        * anything else contributes 0.

        The total of those contributions is returned.  This presumably
        corresponds to counting (zero start index, initial direction) choices
        that clear the array -- confirm against the problem statement.

        Runs in a single pass after one ``sum`` over the input.
        """
        total = sum(nums)
        selections = 0
        left_sum = 0  # sum of everything before the current position
        for value in nums:
            if value == 0:
                right_sum = total - left_sum
                imbalance = abs(right_sum - left_sum)
                # Balanced sides allow both directions; off-by-one allows one.
                if imbalance < 2:
                    selections += 2 - imbalance
            left_sum += value
        return selections
